import os
import json

import numpy as np
import pandas as pd

def print_tex_table(data):
    r"""Print a 2-D table of numbers as LaTeX tabular body rows.

    Every row of ``data`` is written to standard output on its own line.
    Entries are formatted with two decimals, joined by `` & `` and the
    line is terminated by `` \\``.

    ``data`` must expose a two-element ``shape`` and support
    ``data[row][col]`` indexing (for example a 2-D NumPy array).
    Nothing is returned.
    """
    n_rows, n_cols = data.shape
    if n_cols == 0:
        # A table without columns produces no output at all.
        return
    for row in range(n_rows):
        cells = ["%.2f" % data[row][col] for col in range(n_cols)]
        print(" & ".join(cells) + " \\\\")

# Score files are read from the "results" directory located one level above
# the directory that contains this script. (An earlier assignment pointing at
# ~/Documents/results/3stream/classification was immediately overwritten by
# this one and has been removed as dead code.)
repository_base = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
base_dir = os.path.join(repository_base, "results")

# Session identifiers, in the order their rows appear in the final table.
# Each identifier is also the prefix of the score file loaded for it.
sessions = [
    "sub-A",
    "sub-B",
    "sub-C",
    "sub-D",
    "sub-E",
    "sub-F",
    "sub-G",
    "sub-H",
    "sub-I",
    "sub-J",
    "sub-K",
]

# Collect, for every session, the accuracy and MCC stored under the keys
# "1", "2" and "3" of its JSON score file. Both tables end up with one row
# per session and one column per key.
data = []
data_mcc = []
for session_name in sessions:
    scores_path = os.path.join(base_dir, "%s_classification_scores.json" % session_name)
    with open(scores_path, "r") as scores_file:
        session_scores = json.load(scores_file)
    per_stream = [session_scores["%d" % stream_no] for stream_no in range(1, 4)]
    data.append([entry['report']['accuracy'] for entry in per_stream])
    data_mcc.append([entry['mcc'] for entry in per_stream])

# accuracy scores
# Append a trailing column holding each row's mean, then a trailing row
# holding each column's mean (the mean column included).
data = np.array(data)
acc_row_means = data.mean(axis=1, keepdims=True)
data = np.concatenate((data, acc_row_means), axis=1)
acc_col_means = data.mean(axis=0, keepdims=True)
data = np.concatenate((data, acc_col_means), axis=0)
#print("== acc ==")
#print_tex_table(data)

# mcc scores
# Same augmentation as for the accuracy table.
data_mcc = np.array(data_mcc)
mcc_row_means = data_mcc.mean(axis=1, keepdims=True)
data_mcc = np.concatenate((data_mcc, mcc_row_means), axis=1)
mcc_col_means = data_mcc.mean(axis=0, keepdims=True)
data_mcc = np.concatenate((data_mcc, mcc_col_means), axis=0)
#print("== mcc ==")
#print_tex_table(data_mcc)

# compiled
# Interleave the two tables column by column (acc, mcc, acc, mcc, ...).
# The output is sized from `data` instead of a fixed row count, so it stays
# valid when entries are added to or removed from `sessions`.
n_rows, n_cols = data.shape
comp = np.empty((n_rows, 2 * n_cols))
comp[:, 0::2] = data      # even columns: accuracy
comp[:, 1::2] = data_mcc  # odd columns: MCC

#print(comp)
#print("== compiled ==")
#print_tex_table(comp)

# One row label per session plus a label for the final row of column means.
index = sessions + ['Avg']
df = pd.DataFrame(
    data=comp,
    columns=[
        'Stream1_acc', 'Stream1_mcc',
        'Stream2_acc', 'Stream2_mcc',
        'Stream3_acc', 'Stream3_mcc',
        'avg_acc', 'avg_mcc',
    ],
    index=index,
)
print(df)
